import torch.nn as nn


class SegmentEmbedding(nn.Embedding):
    def __init__(self, embed_size=512):
        super().__init__(3, embed_size, padding_idx=0) # 3表示输入的第一个维度大小是3,即两个句子的segment label以及一个填充标签
